﻿using System;
using System.IO;
using Vanara.InteropServices;
using static Vanara.PInvoke.AMSI;

namespace Vanara.PInvoke.Diagnostics
{
	/// <summary>
	/// Simplified classification of an antimalware scan outcome, as returned by the <see cref="AntimalwareScan"/> scan methods.
	/// Each member is backed by the numeric value of the corresponding <c>AMSI_RESULT</c> constant.
	/// </summary>
	public enum ScanResult : uint
	{
		/// <summary>
		/// Known good: nothing was detected, and this verdict is not expected to change after a future definition update.
		/// </summary>
		Clean = (uint)AMSI_RESULT.AMSI_RESULT_CLEAN,

		/// <summary>Nothing was detected, although a future definition update could change this verdict.</summary>
		NotDetected = (uint)AMSI_RESULT.AMSI_RESULT_NOT_DETECTED,

		/// <summary>
		/// A threat level below the maximum was reported, so the content may be malware. This value is not an
		/// <c>AMSI_RESULT</c> constant; it sits immediately after <see cref="NotDetected"/>.
		/// </summary>
		PotentialDetected = (uint)AMSI_RESULT.AMSI_RESULT_NOT_DETECTED + 1,

		/// <summary>Malware was detected; the content should be blocked.</summary>
		Detected = (uint)AMSI_RESULT.AMSI_RESULT_DETECTED,
	}

	/// <summary>Provides scanning of strings and buffers to detect malware using either the system provider or a custom provider.</summary>
	/// <remarks>
	/// When <see cref="Provider"/> is <see langword="null"/>, scans go through the system AMSI functions using a single lazily
	/// created AMSI context shared by all calls, with a new session object per scan. Otherwise the content is wrapped in an
	/// <c>AmsiStream</c> and handed to the custom provider. A failing HRESULT from AMSI or the provider is surfaced through
	/// <c>ThrowIfFailed</c>.
	/// </remarks>
	public static class AntimalwareScan
	{
		// Guards lazy creation of hCtx so concurrent first callers do not each call AmsiInitialize and overwrite one another's context.
		private static readonly object ctxLock = new();
		private static SafeHAMSICONTEXT hCtx;

		/// <summary>
		/// Gets or sets the provider to use for Antimalware scans. If <see langword="null"/>, the system default provider is used.
		/// </summary>
		/// <value>The Antimalware scan provider.</value>
		public static IAntimalwareProvider Provider { get; set; }

		/// <summary>Scans a buffer-full of content for malware.</summary>
		/// <param name="buffer">The buffer from which to read the data to be scanned.</param>
		/// <param name="contentName">The filename, URL, unique script ID, or similar of the content being scanned.</param>
		/// <returns>The result of the scan.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
		public static ScanResult Scan(byte[] buffer, string contentName = null)
		{
			if (buffer is null)
			{
				throw new ArgumentNullException(nameof(buffer));
			}

			unsafe
			{
				// Pin the array so its address stays valid while the native scan reads it. For an empty array the pinned
				// pointer is null, which the IntPtr overload accepts because the length is zero.
				fixed (byte* bufferPtr = buffer)
				{
					return Scan((IntPtr)bufferPtr, (uint)buffer.Length, contentName);
				}
			}
		}

		/// <summary>Scans a buffer-full of content for malware.</summary>
		/// <param name="buffer">The buffer from which to read the data to be scanned.</param>
		/// <param name="bufferLen">The length, in bytes, of the data to be read from <c>buffer</c>.</param>
		/// <param name="contentName">The filename, URL, unique script ID, or similar of the content being scanned.</param>
		/// <returns>The result of the scan.</returns>
		/// <exception cref="ArgumentNullException">
		/// <paramref name="buffer"/> is <see cref="IntPtr.Zero"/> while <paramref name="bufferLen"/> is non-zero.
		/// </exception>
		public static ScanResult Scan(IntPtr buffer, uint bufferLen, string contentName = null)
		{
			// A null pointer is only meaningful for empty content; otherwise it would be handed to the native scan or
			// wrapped in the provider stream as if it addressed bufferLen readable bytes.
			if (buffer == IntPtr.Zero && bufferLen != 0)
			{
				throw new ArgumentNullException(nameof(buffer));
			}

			// Read the property once so a concurrent change cannot switch it to null between the check and the call.
			IAntimalwareProvider provider = Provider;
			AMSI_RESULT result;
			if (provider is null)
			{
				SafeHAMSICONTEXT ctx = EnsureContext();
				using SafeHAMSISESSION session = new(ctx);
				AmsiScanBuffer(session.Context, buffer, bufferLen, contentName, session, out result).ThrowIfFailed();
			}
			else
			{
				// NOTE(review): the trailing 'false' arguments presumably mean the wrappers do not take ownership of (free) the
				// caller's memory — confirm. contentName is not forwarded on this path.
				using AmsiStream stream = new AmsiStream(new SafeCoTaskMemHandle(buffer, bufferLen, false), false);
				provider.Scan(stream, out result).ThrowIfFailed();
			}
			return result.Convert();
		}

		/// <summary>Scans a string for malware.</summary>
		/// <param name="str">The string to be scanned.</param>
		/// <param name="contentName">The filename, URL, unique script ID, or similar of the content being scanned.</param>
		/// <returns>The result of the scan.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="str"/> is <see langword="null"/>.</exception>
		public static ScanResult Scan(string str, string contentName = null)
		{
			if (str is null)
			{
				throw new ArgumentNullException(nameof(str));
			}

			// Read the property once so a concurrent change cannot switch it to null between the check and the call.
			IAntimalwareProvider provider = Provider;
			AMSI_RESULT result;
			if (provider is null)
			{
				SafeHAMSICONTEXT ctx = EnsureContext();
				using SafeHAMSISESSION session = new(ctx);
				AmsiScanString(session.Context, str, contentName, session, out result).ThrowIfFailed();
			}
			else
			{
				// NOTE(review): contentName is not forwarded on the custom-provider path.
				using AmsiStream stream = new AmsiStream(new SafeCoTaskMemString(str), false);
				provider.Scan(stream, out result).ThrowIfFailed();
			}
			return result.Convert();
		}

		/// <summary>Scans a file for malware.</summary>
		/// <param name="file">The file from which to read the data to be scanned. Its full path is used as the content name.</param>
		/// <returns>The result of the scan.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="file"/> is <see langword="null"/>.</exception>
		/// <remarks>The entire file is read into memory before scanning.</remarks>
		public static ScanResult Scan(FileInfo file)
		{
			if (file is null)
			{
				throw new ArgumentNullException(nameof(file));
			}

			return Scan(File.ReadAllBytes(file.FullName), file.FullName);
		}

		/// <summary>Maps a raw <see cref="AMSI_RESULT"/> onto the simplified <see cref="ScanResult"/> categories.</summary>
		/// <remarks>
		/// Values at or above <c>AMSI_RESULT_DETECTED</c> are treated as detections; any other value that is neither clean nor
		/// not-detected is reported as <see cref="ScanResult.PotentialDetected"/>.
		/// </remarks>
		private static ScanResult Convert(this AMSI_RESULT result) => result switch
		{
			AMSI_RESULT.AMSI_RESULT_CLEAN => ScanResult.Clean,
			AMSI_RESULT.AMSI_RESULT_NOT_DETECTED => ScanResult.NotDetected,
			>= AMSI_RESULT.AMSI_RESULT_DETECTED => ScanResult.Detected,
			_ => ScanResult.PotentialDetected,
		};

		/// <summary>Returns the shared AMSI context, creating it on first use or after a previous handle became invalid.</summary>
		/// <returns>A valid AMSI context handle.</returns>
		private static SafeHAMSICONTEXT EnsureContext()
		{
			lock (ctxLock)
			{
				if (hCtx is null || hCtx.IsInvalid)
				{
					// Initialize into a local so the shared field is only replaced after AmsiInitialize reports success.
					AmsiInitialize(Guid.NewGuid().ToString(), out SafeHAMSICONTEXT ctx).ThrowIfFailed();
					hCtx = ctx;
				}
				return hCtx;
			}
		}
	}
}